{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "emb_value tensor([[0., 1., 2.],\n",
      "        [3., 4., 5.]], grad_fn=<EmbeddingBackward0>)\n",
      "loss tensor(15., grad_fn=<SumBackward1>)\n",
      "tensor(indices=tensor([[0, 1]]),\n",
      "       values=tensor([[1., 1., 1.],\n",
      "                      [1., 1., 1.]]),\n",
      "       size=(3, 3), nnz=2, layout=torch.sparse_coo)\n",
      "weight tensor([[-1.0000e+00,  3.5763e-07,  1.0000e+00],\n",
      "        [ 2.0000e+00,  3.0000e+00,  4.0000e+00],\n",
      "        [ 6.0000e+00,  7.0000e+00,  8.0000e+00]], requires_grad=True)\n",
      "tensor(indices=tensor([[0, 1]]),\n",
      "       values=tensor([[1., 1., 1.],\n",
      "                      [1., 1., 1.]]),\n",
      "       size=(3, 3), nnz=2, layout=torch.sparse_coo)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "weight = torch.arange(9).reshape(3, 3).float().requires_grad_()\n",
    "\n",
    "\n",
    "\n",
    "input_keys = torch.tensor([0, 1,],).long()\n",
    "\n",
    "emb_value = F.embedding(input_keys, weight, sparse=True, padding_idx=None, scale_grad_by_freq=False,)\n",
    "\n",
    "print(\"emb_value\", emb_value)\n",
    "\n",
    "loss = emb_value.sum(-1).sum(-1)\n",
    "\n",
    "print(\"loss\", loss)\n",
    "loss.backward()\n",
    "\n",
    "print(weight.grad)\n",
    "\n",
    "# opt = optim.SparseAdam(emb_layer.parameters(), lr= 1)\n",
    "opt = optim.SparseAdam([weight], lr= 1)\n",
    "opt.step()\n",
    "\n",
    "# for each in emb_layer.parameters():\n",
    "#     print('---emb_layer.parameters---')\n",
    "#     print(each)\n",
    "#     print('---emb_layer.parameters.grad---')\n",
    "#     print(each.grad)\n",
    "#     print(weight.grad)\n",
    "\n",
    "print(\"weight\", weight)\n",
    "print(weight.grad)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
